{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook will run a live demo on Jetson Nano using [JetCam](https://github.com/NVIDIA-AI-IOT/jetcam) to acquire images from the camera.  First,\n",
    "let's start the camera.  See the JetCam examples for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetcam.csi_camera import CSICamera\n",
    "# from jetcam.usb_camera import USBCamera\n",
    "\n",
    "camera = CSICamera(width=224, height=224)\n",
    "# camera = USBCamera(width=224, height=224)\n",
    "\n",
    "camera.running = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's connect the camera's value to a widget to display."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetcam.utils import bgr8_to_jpeg\n",
    "import traitlets\n",
    "import ipywidgets\n",
    "\n",
    "image_w = ipywidgets.Image()\n",
    "\n",
    "traitlets.dlink((camera, 'value'), (image_w, 'value'), transform=bgr8_to_jpeg)\n",
    "\n",
    "display(image_w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll load the TensorRT model.  (We assume you followed the conversion notebook and saved to the path ``resnet18_trt.pth``)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch2trt import TRTModule\n",
    "\n",
    "model_trt = TRTModule()\n",
    "model_trt.load_state_dict(torch.load('resnet18_trt.pth'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following function will be used to pre-process images from the camera"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "import torchvision\n",
    "\n",
    "device = torch.device('cuda')\n",
    "mean = 255.0 * np.array([0.485, 0.456, 0.406])\n",
    "stdev = 255.0 * np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "normalize = torchvision.transforms.Normalize(mean, stdev)\n",
    "\n",
    "def preprocess(camera_value):\n",
    "    global device, normalize\n",
    "    x = camera_value\n",
    "    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)\n",
    "    x = x.transpose((2, 0, 1))\n",
    "    x = torch.from_numpy(x).float()\n",
    "    x = normalize(x)\n",
    "    x = x.to(device)\n",
    "    x = x[None, ...]\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This text area will be used to display the class predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = ipywidgets.Textarea()\n",
    "display(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the imagenet labels to associate the neural network output with a class name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('imagenet_labels.json', 'r') as f:\n",
    "    labels = json.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we create our execution function, which we attach as a callback to the camera's ``value`` attribute.\n",
    "\n",
    "Whenever the camera's value is updated (which it will be for each frame, since we set ``camera.running = True``).  This function will be called\n",
    "describing how the value changed.  The new camera value will be stored in ``change['new']``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def execute(change):\n",
    "    image = change['new']\n",
    "    output = model_trt(preprocess(image).half()).detach().cpu().numpy().flatten()\n",
    "    idx = output.argmax()\n",
    "    text.value = labels[idx]\n",
    "\n",
    "camera.observe(execute, names='value')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
